#!/usr/bin/python
# This file is part of KEmuFuzzer.
# 
# KEmuFuzzer is free software: you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the Free Software
# Foundation, either version 3 of the License, or (at your option) any later
# version.
# 
# KEmuFuzzer is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
# A PARTICULAR PURPOSE.  See the GNU General Public License for more details.
# 
# You should have received a copy of the GNU General Public License along with
# KEmuFuzzer.  If not, see <http://www.gnu.org/licenses/>.

import os, sys
# Absolute directory of the script named in sys.argv[0].  It is appended to
# sys.path ahead of the project imports below and is also the base directory
# for the default kernel and floppy paths.
KEMUFUZZER_PATH = os.path.dirname(os.path.abspath(sys.argv[0]))
sys.path.append(KEMUFUZZER_PATH)

import subprocess, signal, tempfile, hashlib
import time, gzip, traceback, shutil
import hashlib
import vmware_driver 
from x86_cpustate import *
from tc import TestCase, tcgen

# Default kernel binary and floppy image, resolved relative to KEMUFUZZER_PATH.
KERNEL = os.path.join(KEMUFUZZER_PATH, "kernel/kernel")
FLOPPY = os.path.join(KEMUFUZZER_PATH, "kernel/floppy.img")
# Write-only handle on /dev/null, handed to child processes as stdin/stdout.
NULL = open("/dev/null", "w")
# Seconds ExecutionModule.run() waits for the post-state file before raising
# Timeout (compared against time.time() deltas).
TIMEOUT = 10
# Timeout value passed to vmware_driver.run().
VMWARE_TIMEOUT = TIMEOUT*4

class ExecutionModule():
    """Base class driving one emulator run through an external process.

    The base constructor does not set ``cmdline`` (the command to spawn) or
    ``poststate`` (the file whose appearance marks the end of the run), and it
    leaves ``executionmodule`` (the name used in log messages) as None;
    ``run()`` reads all three, so subclasses must assign them first.
    """

    def __init__(self, floppy = None, testcase = None, gui = False, kernel_md5 = None, testcase_md5 = None):
        """Store the run parameters and export the KEMUFUZZER_* variables.

        floppy       -- path of the floppy image to boot (may be None)
        testcase     -- path of the test-case, or None when running a bare kernel
        gui          -- whether the emulator should show a display
        kernel_md5   -- kernel checksum string exported to the environment
        testcase_md5 -- test-case checksum string exported to the environment
        """
        self.__process = None
        # NOTE(review): this aliases os.environ instead of copying it, so every
        # assignment made through self.env is visible process-wide.
        self.env = os.environ
        self.executionmodule = None
        self.cwd = None
        self.floppy = floppy
        self.testcase = testcase
        self.gui = gui
        self.env["KEMUFUZZER_KERNEL_VERSION"] = "protected mode + paging"
        # Environment values must be strings: only export checksums we have.
        if kernel_md5 is not None:
            self.env["KEMUFUZZER_KERNEL_CHECKSUM"] = kernel_md5
        if testcase_md5 is not None:
            self.env["KEMUFUZZER_TESTCASE_CHECKSUM"] = testcase_md5

    def __del__(self):
        # Make sure the child does not outlive this object.
        self.kill()

    def run(self):
        """Spawn ``self.cmdline`` and wait until ``self.poststate`` exists.

        Raises Timeout when the file has not appeared TIMEOUT seconds after
        the call started, and Crash when the child exits with a non-zero
        status before the file appears.
        """
        t0 = time.time()

        # Use the instance attribute: a module-level 'testcase' only exists
        # when this file is executed as a script.
        if not self.testcase:
            print("[*] Running %s with kernel %s" % (self.executionmodule, self.floppy))
        else:
            print("[*] Running %s with testcase %s" % (self.executionmodule, self.testcase))

        # A stale post-state would make the wait loop below return at once.
        if os.path.isfile(self.poststate):
            os.unlink(self.poststate)

        self.__process = subprocess.Popen(self.cmdline.split(), env = self.env, cwd = self.cwd, stdin = NULL, stderr = subprocess.STDOUT)
        while not os.path.isfile(self.poststate):
            t1 = time.time()
            if (t1 > t0 + TIMEOUT):
                raise Timeout()
            r = self.__process.poll()
            # A clean exit (status 0) without a post-state keeps waiting until
            # the timeout; only a non-zero status counts as a crash.
            if (r is not None and r != 0):
                raise Crash()
            time.sleep(0.01)

        return

    def kill(self):
        """Send SIGKILL to the child process, if one was started."""
        if self.__process is not None:
            print("[*] Killing %s (#%d)" % (self.executionmodule, self.__process.pid))
            try:
                os.kill(self.__process.pid, signal.SIGKILL)
            except OSError:
                # Best effort: the process may already be gone.
                pass
            self.__process = None

class KVMExecutionModule(ExecutionModule):
    """Replay an existing pre-state with the bundled ``kvm`` helper binary."""
    KVM_EXE = os.path.join(KEMUFUZZER_PATH, "./kvm")

    def __init__(self, prestate, poststate, floppy, kernel, testcase, gui, kernel_md5, testcase_md5):
        """Prepare a KVM run.

        prestate  -- existing state file handed to the helper as input
        poststate -- path the helper writes; its directory must exist
        floppy, kernel -- accepted for signature parity with the other
                          modules but not used here

        Raises AssertionError when /dev/kvm cannot be opened, when the
        pre-state is missing or when the output directory does not exist.
        """
        ExecutionModule.__init__(self, None, testcase, gui, kernel_md5, testcase_md5)

        self.executionmodule = "KVM"

        # Raise explicitly instead of 'assert False' so the check survives -O,
        # and close the probe handle right away.
        try:
            open("/dev/kvm").close()
        except Exception:
            raise AssertionError("[!] Unable to open /dev/kvm")

        self.prestate = prestate
        self.poststate = poststate
        # Remove the variables exported by the base constructor.
        for name in ("KEMUFUZZER_KERNEL_VERSION",
                     "KEMUFUZZER_KERNEL_CHECKSUM",
                     "KEMUFUZZER_TESTCASE_CHECKSUM"):
            if name in self.env:
                del self.env[name]
        assert os.path.isfile(os.path.realpath(prestate)), "[!] Invalid input state '%s'" % prestate
        self.cmdline = "%s %s %s" % (self.KVM_EXE, prestate, poststate)
        assert os.path.isdir(os.path.dirname(poststate)), "[!] Invalid output directory '%s'" % os.path.dirname(poststate)

class QEMUExecutionModule(ExecutionModule):
    """Boot the floppy image with the ``qemu`` executable."""
    QEMU_EXE = "qemu"

    def __init__(self, prestate, poststate, floppy, kernel, testcase, gui, kernel_md5, testcase_md5):
        """Build the QEMU command line and export the state file paths.

        The directories of both ``prestate`` and ``poststate`` must exist,
        otherwise an AssertionError is raised.  ``kernel`` is not used here.
        """
        ExecutionModule.__init__(self, floppy, testcase, gui, kernel_md5, testcase_md5)

        self.executionmodule = "QEMU"
        self.prestate, self.poststate = prestate, poststate

        # The state file paths are passed through the environment.
        self.env["KEMUFUZZER_PRE_STATE"] = prestate
        self.env["KEMUFUZZER_POST_STATE"] = poststate

        # Headless unless a GUI was requested.
        display_opt = "-nographic"
        if self.gui:
            display_opt = ""
        self.cmdline = "%s %s -m 4 -fda %s" % (self.QEMU_EXE, display_opt, self.floppy)

        for state_path in (prestate, poststate):
            state_dir = os.path.dirname(state_path)
            assert os.path.isdir(state_dir), "[!] Invalid output directory '%s'" % state_dir

class VMWAREExecutionModule(ExecutionModule):
    """Run the floppy image under VMware by delegating to ``vmware_driver``."""

    def __init__(self, prestate, poststate, floppy, kernel, testcase, gui, kernel_md5, testcase_md5):
        """Record the run parameters.

        The directories of both ``prestate`` and ``poststate`` must exist,
        otherwise an AssertionError is raised.
        """
        ExecutionModule.__init__(self, floppy, testcase, gui, kernel_md5, testcase_md5)

        self.floppy = floppy
        # Kept on the instance so that run() does not depend on a global.
        self.kernel = kernel
        self.executionmodule = "VMWARE"

        self.prestate = prestate
        self.poststate = poststate

        self.gui = gui

        assert os.path.isdir(os.path.dirname(prestate)), "[!] Invalid output directory '%s'" % os.path.dirname(prestate)
        assert os.path.isdir(os.path.dirname(poststate)), "[!] Invalid output directory '%s'" % os.path.dirname(poststate)

    def run(self):
        """Export the KEMUFUZZER_* variables and call ``vmware_driver.run``."""
        # Iterate over a snapshot: self.env may be os.environ itself, which
        # must not be modified while a live view of it is being iterated.
        for k, v in list(self.env.items()):
            if k.startswith("KEMUFUZZER"):
                os.environ[k] = v

        os.environ["KEMUFUZZER_PRE_STATE"] = self.prestate
        os.environ["KEMUFUZZER_POST_STATE"] = self.poststate

        vmware_driver.run(disk = self.floppy, kernel = self.kernel, gui = self.gui, timeout = VMWARE_TIMEOUT)

    def kill(self):
        """No-op: run() does not start a process through this object."""
        pass

class BOCHSExecutionModule(ExecutionModule):
    """Boot the floppy image with ``bochs`` using a generated bochsrc file."""
    BOCHS_EXE = "bochs"

    def __init__(self, prestate, poststate, floppy, kernel, testcase, gui, kernel_md5, testcase_md5):
        """Write a temporary Bochs configuration and build the command line.

        The directories of both ``prestate`` and ``poststate`` must exist,
        otherwise an AssertionError is raised.  ``kernel`` is not used here.
        """
        # Defined up front so that __del__ is safe if construction fails early.
        self.configfile = None
        ExecutionModule.__init__(self, floppy, testcase, gui, kernel_md5, testcase_md5)

        self.executionmodule = "BOCHS"

        self.prestate = prestate
        self.poststate = poststate
        self.env["KEMUFUZZER_PRE_STATE"] = prestate
        self.env["KEMUFUZZER_POST_STATE"] = poststate

        # create temporary configuration file
        if self.gui:
            display = "x"
        else:
            display = "nogui"

        # mkstemp creates the file atomically, unlike the race-prone mktemp.
        fd, self.configfile = tempfile.mkstemp(prefix = "kemufuzzer-bochsrc-")
        f = os.fdopen(fd, "w")
        try:
            f.write("""\
romimage: file=%s
vgaromimage: file=%s
log: -
panic: action=fatal
megs: 4
floppya: 1_44=%s, status=inserted
display_library: %s\n""" % ("/usr/share/bochs/BIOS-bochs-latest", 
                            "/usr/share/vgabios/vgabios.bin", 
                            self.floppy, display))
        finally:
            f.close()

        self.cmdline = "%s -q -f %s" % (self.BOCHS_EXE, self.configfile)
        assert os.path.isdir(os.path.dirname(prestate)), "[!] Invalid output directory '%s'" % os.path.dirname(prestate)
        assert os.path.isdir(os.path.dirname(poststate)), "[!] Invalid output directory '%s'" % os.path.dirname(poststate)


    def __del__(self):
        """Kill the child process (as the base class does) and remove the config file."""
        try:
            self.kill()
        finally:
            configfile = getattr(self, "configfile", None)
            if configfile is not None:
                try:
                    os.unlink(configfile)
                except OSError:
                    # Already removed; nothing left to clean up.
                    pass

class VBOXExecutionModule(ExecutionModule):
    """Boot the floppy image in the VirtualBox VM named by ``VBOX_VM``.

    The emulator is pointed at uncompressed temporary state files; after a
    successful run they are gzip-compressed into the paths the caller asked
    for.
    """
    VBOX_EXE_NOGUI = "VBoxHeadless"
    VBOX_EXE_GUI = "VBoxSDL"
    VBOX_VM = "KEmuFuzzer"

    def __init__(self, prestate, poststate, floppy, kernel, testcase, gui, kernel_md5, testcase_md5):
        """Attach the floppy to the VM and build the command line.

        The directories of both ``prestate`` and ``poststate`` must exist,
        otherwise an AssertionError is raised.  ``kernel`` is not used here.
        """
        ExecutionModule.__init__(self, floppy, testcase, gui, kernel_md5, testcase_md5)

        self.executionmodule = "VBOX"

        # Names only: the files themselves are expected to be created by the
        # emulator, and run() waits for the post-state to appear.
        self.prestate = tempfile.mktemp(prefix = "kemufuzzer-")
        self.poststate = tempfile.mktemp(prefix = "kemufuzzer-")
        self.prestate_real = prestate
        self.poststate_real = poststate
        self.env["KEMUFUZZER_PRE_STATE"] = self.prestate
        self.env["KEMUFUZZER_POST_STATE"] = self.poststate
        # Change cwd to log on a local file system
        od = os.path.abspath(os.curdir)
        os.chdir("/tmp")
        try:
            cmd = "VBoxManage modifyvm %s --floppy %s" % (self.VBOX_VM, floppy)
            subprocess.call(cmd.split(), stdout = NULL)
        finally:
            # Always restore the caller's working directory.
            os.chdir(od)

        if self.gui:
            self.cmdline = "%s --startvm %s" % (self.VBOX_EXE_GUI, self.VBOX_VM)
        else:
            self.cmdline = "%s --startvm %s" % (self.VBOX_EXE_NOGUI, self.VBOX_VM)
        assert os.path.isdir(os.path.dirname(prestate)), "[!] Invalid output directory '%s'" % os.path.dirname(prestate)
        assert os.path.isdir(os.path.dirname(poststate)), "[!] Invalid output directory '%s'" % os.path.dirname(poststate)

    def _gzip_copy(self, src, dst):
        """Write the contents of ``src`` to ``dst`` gzip-compressed."""
        fin = open(src, "rb")
        try:
            fout = gzip.open(dst, "w")
            try:
                fout.write(fin.read())
            finally:
                fout.close()
        finally:
            fin.close()

    def run(self):
        """Run the VM from /tmp, then compress the states and detach floppies.

        Exceptions raised by ExecutionModule.run() (Timeout, Crash, ...)
        propagate unchanged after the working directory has been restored.
        """
        # Use temporary file because output is not gzipped
        od = os.path.abspath(os.curdir)
        os.chdir("/tmp")
        try:
            ExecutionModule.run(self)
        finally:
            # try/finally keeps the original traceback, unlike 'raise e'.
            os.chdir(od)

        # Gzip output
        self._gzip_copy(self.prestate, self.prestate_real)
        self._gzip_copy(self.poststate, self.poststate_real)
        os.unlink(self.prestate)
        os.unlink(self.poststate)
        self.prestate = self.prestate_real
        self.poststate = self.poststate_real

        # Delete floppies
        od = os.path.abspath(os.curdir)
        os.chdir("/tmp")
        try:
            os.system("for f in $(VBoxManage list floppies 2> /dev/null | grep Path: | cut -c 13-); do VBoxManage closemedium floppy $f 2>&1 > /dev/null ; done")
        finally:
            os.chdir(od)

        
# Emulator name -> execution module class implementing it.
EM = dict([
    ("QEMU",   QEMUExecutionModule),
    ("KVM",    KVMExecutionModule),
    ("BOCHS",  BOCHSExecutionModule),
    ("VBOX",   VBOXExecutionModule),
    ("VMWARE", VMWAREExecutionModule),
])

def create_dummy_state(f, typ, emu, kernel_md5, testcase_md5, kernel_version):
    """Write a gzip-compressed state file that contains only a header.

    f              -- path of the file to create
    typ            -- value stored in the header's ``type`` field (callers in
                      this file pass e.g. PRE_TESTCASE | TIMEOUT_TESTCASE)
    emu            -- execution module instance; its class must be one of the
                      classes registered in EM
    kernel_md5, testcase_md5, kernel_version -- copied into the header

    Raises AssertionError when the class of ``emu`` cannot be mapped to an
    emulator id through EM and EMULATORS.
    """
    # Map each execution module class to the id under which EMULATORS lists
    # the same emulator name.
    class_to_id = {}
    for name, cls in EM.items():
        for candidate_id, candidate_name in EMULATORS.items():
            if name == candidate_name:
                class_to_id[cls] = candidate_id
                break
    assert emu.__class__ in class_to_id, "[!] Invalid emulator '%s'" % emu

    emu_id = class_to_id[emu.__class__]

    hdr = header_t()
    hdr.magic      = CPU_STATE_MAGIC
    hdr.version    = CPU_STATE_VERSION
    hdr.emulator   = emu_id
    hdr.kernel_version = kernel_version
    hdr.kernel_checksum = kernel_md5
    hdr.testcase_checksum = testcase_md5
    hdr.type       = typ
    # No CPU or memory data follows the header.
    hdr.cpusno     = 0
    hdr.mem_size   = 0
    hdr.ioports[0] = KEMUFUZZER_HYPERCALL_START_TESTCASE
    hdr.ioports[1] = KEMUFUZZER_HYPERCALL_STOP_TESTCASE

    # Serialise the raw bytes of the header structure.
    s = string_at(byref(hdr), sizeof(hdr))
    out = gzip.open(f, 'w')
    try:
        out.write(s)
    finally:
        # Close even when the write fails so the handle is not leaked.
        out.close()

if __name__ == "__main__":
    t0 = time.time()
    
    print " ".join(sys.argv)
    args = {"kernel" : KERNEL,
            "floppy" : FLOPPY,
            "testcase" : None,
            "pre" : None,
            "post" : None,
            "emu" : None,
            "outdir" : None,
            "kerneldir" : None,
            "nodel": False,
            "nokill" : False,
            "gui" : False}

    for i in sys.argv[1:]:
        a, v = i.split(":")
        args[a] = v

    kernel = args["kernel"]
    floppy = args["floppy"]
    emu = args["emu"]
    if emu == "ORACLE": emu = "KVM"
    prestate = args["pre"]
    poststate = args["post"]
    outdir = args["outdir"]
    testcase = args["testcase"]
    kerneldir = args["kerneldir"]
    gui = bool(args["gui"])
    nokill = bool(args["nokill"])
    nodel = bool(args["nodel"])

    assert testcase or (not testcase and os.path.isfile(os.path.realpath(kernel))), \
        "[!] Kernel image (%s) is missing" % kernel
    assert not testcase or (testcase.endswith(".testcase") and os.path.isfile(os.path.realpath(testcase))), \
                                "[!] Invalid testcase (%s)" % testcase
    assert emu and emu in EM    
    assert (emu == "KVM" and prestate) or (emu != "KVM")
    assert kernel and floppy
    assert os.path.isfile(kernel) and os.path.isfile(floppy)

    kernel = os.path.abspath(kernel)
    floppy = os.path.abspath(floppy)
    kernel_md5 = hashlib.md5(open(kernel).read()).hexdigest()
    testcase_md5 = None
    kernel_ver = "protected mode "

    # Make a backup of the kernel in case it is needed in the future
    if kerneldir:
        assert os.path.isdir(kerneldir)
        if not os.path.isfile(os.path.join(kerneldir, kernel_md5)):
            open(os.path.join(kerneldir, kernel_md5), "w").write(open(kernel).read())

    patched_floppy = floppy
    if testcase:
        try:
            testcase_md5 = hashlib.md5(open(testcase).read()).hexdigest()
            patched_floppy = tcgen(testcase, kernel, floppy)
        except Exception:
            traceback.print_exc()
            print "[!] Invalid test-case (%s)" % testcase
            exit(1)

    assert (prestate and poststate) or (outdir and testcase) or (prestate and outdir and emu == "KVM")

    if not prestate:
        prestate = os.path.join(outdir, os.path.basename(testcase).replace(".testcase", ".pre"))

    if not poststate:
        if emu != "KVM":
            poststate = os.path.join(outdir, os.path.basename(testcase).replace(".testcase", ".post"))
        else:
            if testcase:
                f = os.path.basename(testcase).replace(".testcase", ".post")
            else:
                f = os.path.basename(prestate).replace(".pre", ".post")
            poststate = os.path.join(outdir, os.path.split(os.path.dirname(prestate))[1].lower(), f)

    if emu == "KVM":
        # update test-case checksum and kernel-md5 (we need them in case of a crash)
        try:
            pre = X86Dump(gzip.open(prestate).read())
        except IOError:
            pre = XX86Dump(open(prestate).read())
        kernel_md5 = pre.hdr.kernel_checksum
        kernel_ver = pre.hdr.kernel_version
        testcase_md5 = pre.hdr.testcase_checksum

    if not os.path.isfile(prestate) or emu != "KVM":
        tmp_prestate = tempfile.mktemp(prefix = "kemufuzzer-pre-")
    else:
        # Do not use temporary files if prestate exists
        tmp_prestate = prestate

    tmp_poststate = tempfile.mktemp(prefix = "kemufuzzer-post-")

    emu = EM[emu](tmp_prestate, tmp_poststate, patched_floppy, kernel, testcase, gui, kernel_md5, testcase_md5)

    try:
        emu.run()
        assert os.path.isfile(os.path.realpath(tmp_prestate)) and os.path.isfile(tmp_poststate), "[!] Error while running"
        print "[*] Wrote dump to %s" % ((tmp_prestate, tmp_poststate),)
        r = 0
    except Timeout:
        print "[!] Execution timeout"
        # Create fake states
        if not os.path.isfile(os.path.realpath(tmp_prestate)):
            create_dummy_state(os.path.realpath(tmp_prestate), PRE_TESTCASE | TIMEOUT_TESTCASE, emu, 
                               kernel_md5, testcase_md5, kernel_ver)
        if not os.path.isfile(os.path.realpath(tmp_poststate)):
            create_dummy_state(os.path.realpath(tmp_poststate), POST_TESTCASE | TIMEOUT_TESTCASE, emu,
                               kernel_md5, testcase_md5, kernel_ver)
        r = 1
    except Crash:
        print "[!] Crash"
        # Create fake states
        if not os.path.isfile(os.path.realpath(tmp_prestate)):
            create_dummy_state(os.path.realpath(tmp_prestate), PRE_TESTCASE | CRASH_TESTCASE, emu,
                               kernel_md5, testcase_md5, kernel_ver)
        if not os.path.isfile(os.path.realpath(tmp_poststate)):
            create_dummy_state(os.path.realpath(tmp_poststate), POST_TESTCASE | CRASH_TESTCASE, emu,
                               kernel_md5, testcase_md5, kernel_ver)

        r = 1

    if nokill:
        print "[*] Press RETURN to kill the emulator...",
        raw_input()

    emu.kill()

    if patched_floppy != floppy and not nodel:
        print "[*] Deleting temporary floppy image '%s'" % patched_floppy
        os.unlink(patched_floppy)

    t1 = time.time()
    print "[%s] Completed in %.3f secs." % ("*!"[r], t1 - t0)
    
    if tmp_prestate != prestate and os.path.isfile(os.path.realpath(tmp_prestate)):
        print "[*] Flushing state '%s' -> '%s'" % (tmp_prestate, prestate)
        shutil.move(tmp_prestate, prestate)

    if os.path.isfile(tmp_poststate):
        print "[*] Flushing state '%s' -> '%s'" % (tmp_poststate, poststate)
        shutil.move(tmp_poststate, poststate)

    t1 = time.time()
    print "[%s] Completed in %.3f secs." % ("*!"[r], t1 - t0)
    
    exit(0)
